{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.3 Hyperparameter Tuning - K-Fold Cross-Validation & GridSearch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Recap - model-building code from Section 5.2**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(max_depth=3, random_state=123)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1. Load the data and apply simple preprocessing\n",
    "import pandas as pd\n",
    "df = pd.read_excel('员工离职预测模型.xlsx')  # employee attrition dataset\n",
    "df = df.replace({'工资': {'低': 0, '中': 1, '高': 2}})  # encode salary level (工资) low/medium/high as 0/1/2\n",
    "\n",
    "# 2. Separate the features from the target\n",
    "X = df.drop(columns='离职')  # every column except the target 离职 (left the company)\n",
    "y = df['离职']\n",
    "\n",
    "# 3. Split into training and test sets (80% / 20%)\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)\n",
    "\n",
    "# 4. Build and train the model\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "model = DecisionTreeClassifier(max_depth=3, random_state=123)\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**5.3.1 K-Fold Cross-Validation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.96666667, 0.96066667, 0.959     , 0.96233333, 0.91366667])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import cross_val_score  # cross-validation helper\n",
    "acc = cross_val_score(model, X, y, cv=5)  # 5 folds; the default score for a classifier is accuracy\n",
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9524666666666667"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc.mean()"
   ]
  },
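  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What `cv=5` does can be made concrete with a minimal sketch. It assumes scikit-learn is installed and uses synthetic data from `make_classification` in place of the employee dataset. `X_demo`, `y_demo` and `clf_demo` are hypothetical names. For a classifier, an integer `cv` makes `cross_val_score` split the data with `StratifiedKFold` (no shuffling): each of the 5 folds serves once as the validation set while the model is trained on the other 4, and every fold keeps roughly the same class proportions. Looping over the folds by hand should therefore reproduce the same 5 accuracy scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: 5-fold cross-validation by hand vs. cross_val_score (assumes scikit-learn; synthetic data)\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Hypothetical stand-in data: 500 samples, 6 features, binary target\n",
    "X_demo, y_demo = make_classification(n_samples=500, n_features=6, random_state=0)\n",
    "clf_demo = DecisionTreeClassifier(max_depth=3, random_state=123)\n",
    "\n",
    "# For a classifier, cv=5 corresponds to StratifiedKFold(n_splits=5) without shuffling\n",
    "skf = StratifiedKFold(n_splits=5)\n",
    "manual_scores = []\n",
    "for train_idx, val_idx in skf.split(X_demo, y_demo):\n",
    "    clf_demo.fit(X_demo[train_idx], y_demo[train_idx])                      # train on the other 4 folds\n",
    "    manual_scores.append(clf_demo.score(X_demo[val_idx], y_demo[val_idx]))  # accuracy on the held-out fold\n",
    "manual_scores = np.array(manual_scores)\n",
    "\n",
    "auto_scores = cross_val_score(DecisionTreeClassifier(max_depth=3, random_state=123), X_demo, y_demo, cv=5)\n",
    "print(manual_scores)\n",
    "print(auto_scores)\n",
    "print('same scores:', np.allclose(manual_scores, auto_scores))"
   ]
  },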
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.97146884, 0.9674637 , 0.96641351, 0.97047305, 0.95030156])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import cross_val_score\n",
    "acc = cross_val_score(model, X, y, scoring='roc_auc', cv=5)  # per-fold ROC AUC instead of accuracy\n",
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9652241309284616"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc.mean()"
   ]
  },
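  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two cross-validations above fit the same 5 fold models twice, once for each metric. One possible alternative is `sklearn.model_selection.cross_validate`. It accepts a list of scorer names and returns a dict containing a `test_accuracy` array and a `test_roc_auc` array, so each fold is fitted only once. The sketch assumes scikit-learn and uses synthetic data. `X_demo`, `y_demo` and `tree_demo` are hypothetical names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: accuracy and ROC AUC from a single cross-validation run (assumes scikit-learn; synthetic data)\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import cross_val_score, cross_validate\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "X_demo, y_demo = make_classification(n_samples=500, n_features=6, random_state=0)  # hypothetical stand-in data\n",
    "tree_demo = DecisionTreeClassifier(max_depth=3, random_state=123)\n",
    "\n",
    "# One call, two metrics: one test_<metric> array per metric, one entry per fold\n",
    "cv_out = cross_validate(tree_demo, X_demo, y_demo, cv=5, scoring=['accuracy', 'roc_auc'])\n",
    "print(sorted(cv_out))\n",
    "print('mean accuracy:', cv_out['test_accuracy'].mean())\n",
    "print('mean ROC AUC:', cv_out['test_roc_auc'].mean())\n",
    "\n",
    "# The per-fold values should match two separate cross_val_score calls\n",
    "print(np.allclose(cv_out['test_accuracy'], cross_val_score(tree_demo, X_demo, y_demo, cv=5)))\n",
    "print(np.allclose(cv_out['test_roc_auc'], cross_val_score(tree_demo, X_demo, y_demo, cv=5, scoring='roc_auc')))"
   ]
  },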
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**5.3.2 GridSearch (Grid Search)**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**1. Tuning a single parameter**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': 7}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
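    "from sklearn.model_selection import GridSearchCV  # grid search over candidate hyperparameter values\n",
    "\n",
    "# Candidate values for max_depth\n",
    "parameters = {'max_depth': [3, 5, 7, 9, 11]}\n",
    "# Build the decision tree classifier\n",
    "model = DecisionTreeClassifier()  # no fixed hyperparameters here, because they are what we are tuning\n",
    "\n",
    "# Grid search\n",
    "grid_search = GridSearchCV(model, parameters, scoring='roc_auc', cv=5)  # cv=5: 5-fold cross-validation for every candidate (the default is 5 since scikit-learn 0.22); scoring='roc_auc': rank candidates by ROC AUC instead of the default accuracy\n",
    "grid_search.fit(X_train, y_train)  # run the search on the training set\n",
    "\n",
    "# Best parameter value found\n",
    "grid_search.best_params_"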
   ]
  },
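  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `GridSearchCV` runs the K-fold cross-validation from 5.3.1 once for every candidate value and keeps the value with the highest mean score. The minimal sketch below checks this on synthetic data. It assumes scikit-learn, and `X_demo`, `y_demo`, `depths` and `manual_means` are hypothetical names. A fixed `random_state` is added to the tree only so that both routes fit identical trees."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: grid search as a loop of cross-validations (assumes scikit-learn; synthetic data)\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import GridSearchCV, cross_val_score\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "X_demo, y_demo = make_classification(n_samples=600, n_features=6, random_state=1)\n",
    "depths = [3, 5, 7, 9, 11]\n",
    "\n",
    "# Manual route: mean 5-fold ROC AUC for each candidate max_depth\n",
    "manual_means = {d: cross_val_score(DecisionTreeClassifier(max_depth=d, random_state=0),\n",
    "                                   X_demo, y_demo, scoring='roc_auc', cv=5).mean()\n",
    "                for d in depths}\n",
    "best_manual = max(manual_means, key=manual_means.get)  # first depth with the highest mean score\n",
    "\n",
    "# GridSearchCV route over the same candidates\n",
    "gs_demo = GridSearchCV(DecisionTreeClassifier(random_state=0), {'max_depth': depths},\n",
    "                       scoring='roc_auc', cv=5)\n",
    "gs_demo.fit(X_demo, y_demo)\n",
    "\n",
    "print(manual_means)\n",
    "print(best_manual, gs_demo.best_params_)\n",
    "print(np.allclose(list(manual_means.values()), gs_demo.cv_results_['mean_test_score']))"
   ]
  },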
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the line below to view the official documentation for GridSearchCV\n",
    "# GridSearchCV?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Supplementary note: generating candidate values in bulk**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "parameters = {'max_depth': np.arange(1, 10, 2)}  # 1, 3, 5, 7, 9 (the stop value 10 is excluded)"
   ]
  },
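  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same approach works for any grid. `np.arange(start, stop, step)` excludes `stop`. `np.linspace(start, stop, num)` returns `num` evenly spaced values with both ends included, which can be handy for float-valued hyperparameters. The sketch is illustrative only: `min_weight_fraction_leaf` is just an example of a float parameter, and the ranges are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: building candidate lists for a parameter grid with NumPy\n",
    "import numpy as np\n",
    "\n",
    "depths = np.arange(1, 10, 2)          # start=1, stop=10 (excluded), step=2\n",
    "splits = np.arange(5, 16, 2)          # 5, 7, ..., 15, the min_samples_split candidates used later\n",
    "fractions = np.linspace(0.0, 0.2, 5)  # 5 evenly spaced floats, both ends included\n",
    "\n",
    "# A dict like this can be passed to GridSearchCV as its parameter grid\n",
    "grid_demo = {'max_depth': depths, 'min_samples_split': splits, 'min_weight_fraction_leaf': fractions}\n",
    "print(depths)                         # [1 3 5 7 9]\n",
    "print(splits)\n",
    "print(fractions)"
   ]
  },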
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2. Checking the effect of tuning**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rebuild the model with max_depth=7 and evaluate it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.1 Accuracy of the new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.982\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the model with max_depth=7\n",
    "model = DecisionTreeClassifier(max_depth=7)  # value chosen by the grid search above\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "# Overall accuracy on the test set\n",
    "y_pred = model.predict(X_test)\n",
    "from sklearn.metrics import accuracy_score\n",
    "score = accuracy_score(y_test, y_pred)  # argument order is (y_true, y_pred)\n",
    "print(score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test-set accuracy was 0.957 before tuning and is 0.982 now, a clear improvement."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.2 ROC curve and AUC of the new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlWElEQVR4nO3df3DU1f3v8dcmm2zCj10ENAQSYrBaUSrWzQUJX6aj1Tjo4PXWXnLHjqiFjhltEVLtmNKrhfF+c+0PBn8FtIKOM2gz/uq0M6mS770tRKE/SEOv38J3tEANSGK+iSUbfiVk99w/kl2y7AbyWZIcNp/nY2Yn5Ozns3v2TPS89v35cTzGGCMAAABLMmx3AAAAuBthBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVXtsdGIpIJKIjR45o4sSJ8ng8trsDAACGwBijrq4uTZ8+XRkZg9c/0iKMHDlyRIWFhba7AQAAUnDo0CEVFBQM+nxahJGJEydK6vswfr/fcm8AAMBQhEIhFRYWxubxwaRFGIkemvH7/YQRAADSzPlOseAEVgAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGCV4zCyY8cOLVmyRNOnT5fH49GvfvWr8+6zfft2BYNB5eTkaNasWdq0aVMqfQUAAGOQ4zBy/PhxzZ07V88///yQtj948KBuv/12LVq0SE1NTfrhD3+olStX6u2333bcWQAAMPY4Xptm8eLFWrx48ZC337Rpk2bOnKkNGzZIkmbPnq3du3frZz/7me6++26nbw8AAMaYEV8ob9euXSorK4tru+2227R582adPn1aWVlZCft0d3eru7s79nsoFBrpbgIALnKRiFHYGIUjRpHoz4gS2uKeN0bhiNQbiSTdNvE1lWT/s543/fsNfD7WNuD5s/a92N19Q4HmzAhYee8RDyOtra3Ky8uLa8vLy1Nvb6/a29uVn5+fsE91dbXWrl070l0DgBFjjFHEaJAJ66yJaogT6mCvlWySHDghnmnrm9B7B3mfgRPqYH2ODHids99nKCEhsW9K/joDn+9vw8j66sxLxm4YkRKXDjb9CXGwJYWrqqpUWVkZ+z0UCqmwsHDkOghApv9/+L1DmFCGOkkOnEzOPXmdmSSdfOtM/pqKfz6h/0r41tsbMUn2GeR1Bj4/6LfrviCC0ZWZ4VGmx6OMDPX/9Axo88ib4VGGp78tw6MMj/p/DmzzxL9OQlvfz8yM6L+VpG3ga0mZGRnK8EiDTHkXjSsvm2DtvUc8jEybNk2tra1xbW1tbfJ6vZoyZUrSfXw+n3w+30h3DS5ijNGx7l51HOvRqd7woBNt8pJtxNG3zsRvdomTV/IycPy30XDY+bfOgRPh4O8z4PkBr58GVeQxxeNRwuQWPyEmn1DPtCmx7axJcuBEmnySPNckO7BNSfY/6/kkE/qZNsVN6LHtkrzmwG2dhgSkrxEPIwsWLNBvfvObuLZt27appKQk6fkiwFB194b1xfEedRzrUfuxbnUc61HH8e7+38/8u+NYt9qP96inN2K7y2ntfN86z0xkiZNsxjkmycTJK/H5c33rTD6hDXh+0AktcZKMvVbmhU2SicEhfuwyMzyDVoYBN3IcRo4dO6a///3vsd8PHjyoPXv2aPLkyZo5c6aqqqr02Wef6bXXXpMkVVRU6Pnnn1dlZaW+853vaNeuXdq8ebPeeOON4fsUGBMiEaOjJ0/3hYckYaIjFjj6wkfXqV7H7zEuO1PjsjMHmVScl2wH/4Z31rfSJJPk4BPWgAl7kG+t8eXmJBPqECbJ+NfXOSZ8T3+JmckTwMhwHEZ2796tm266KfZ79NyO++67T6+++qpaWlrU3Nwce764uFh1dXVavXq1XnjhBU2fPl3PPvssl/W6xIme3oTKRfuxngH/PlPJ+OeJHscnqXkzPJoyIVtT
xvs0ZUK2pk7wacr4bE2ZEP39zHNTxvuUm505Qp8UAJAqjzEX/5HiUCikQCCgzs5O+f1+291xvWPdvWrtPKnWzr4w0X6sWx0DKhcDqxgnT4cdv34gN6svSERDRH+QmDqhP2T0h41LJ/jkz/XyjR0ALlJDnb9H5WoapAdjjDpPnlZL5ym1dp7q/3my72eor62185S6up0dHvF5MzR1QmKYmDogaESrGpeMy1a2lyWTAMBNCCMuEYkYtR/vjoWMz0OnBoSOk7H27iGe5Dkxx6tp/hxd5vcNfoikv31cdibVCwDAoAgjY0BvOKK2ru6EcBGtZkTDR+8Qz8eYMj5bef4c5QdyNC0Q/Zmr/ECO8vx9bRN8/OkAAIYHM0oaaO44oZ372xUxUujU6YRqRvux7iHdYCnDI1060dcXLPwDg0aOpvlzlB/I1WV+n3KyOMkTADB6CCMXuc4Tp/XfX9ypz0Pd59wuK9MzoJpxpooxsLpx6QSfvJmcjwEAuLgQRi4CoVOn9e+HO9X46T/V8Em79rWGpP5Kx+lIRKdOR5Tn9+m6gkma6PMmPXQyZXw2dyAEAKQlwsgo6+4Na19Ll/566Kj+evio/nroqPb/5/Fz7pPtzVDNt4IKFl0ySr0EAGD0EEZGUDhidOA/j2lPf/D4f4c7ta8lpNPhxBM8ZkzK1fWFk3TjFVP0Xy6/RD7vmfM2JuVm6ZLx2aPZdQAARg1hZJgYY3Sk81RfxaM/fHx0uFPHexJv+jV5fLauKwhobsEkzS0M6LqCSZo6gYUBAQDuRBgZgs6Tp3W8u1dHT5zW0RM9OnrydN+/T/bo6InT2t92TH89fFTtx3oS9s3NytRXZgRioeP6wkkquCSX+24AANCPMDJA9H4dXxzv0V+a/6nGT/seh/95ckj7ezM8ujp/Yl/oKJik6woD+tKlE7iCBQCAc3B9GOkNR/TM//lEh744oV0HOga9hDYr06NAbrYmjcvSpNwsTRp35t8zLsnVdQWTdO10P/foAADAIdeHkT/94ws993//HtcWyM3S3MJJCs68RMGiSzS3MKCJOVmWeggAwNjm+jByovvMCab/+xtf0X+9fgbLzAMAMIpcH0bCpu8y2xtmTtL/mDfTcm8AAHAf159ZGe5f1MWb4fqhAADACtfPwNEwQhYBAMAO10/BVEYAALDL9TPwmcoINyEDAMAGwkisMkIYAQDABsJI/9U0GdyeHQAAK1wfRnqpjAAAYJXrw0ikP4xkEkYAALDC9WGklzACAIBVrg8jVEYAALDL9WEkegIrYQQAADsII9HKCFfTAABgBWEkGkYyCSMAANjg+jDSS2UEAACrXB9GOIEVAAC7XB9GuLQXAAC7XB9GIoY7sAIAYJPrw0hvmFV7AQCwyfVhhMoIAAB2uT6M9EYikli1FwAAW1wfRsJ9WYTKCAAAlhBGopURwggAAFYQRqiMAABgFWGkvzLCfUYAALCDMNJ3MQ1hBAAASwgjVEYAALCKMMLt4AEAsIowwqq9AABYRRihMgIAgFWuDyOs2gsAgF2uDyPRtWkIIwAA2OH6MBJdtZcwAgCAHa4PI6zaCwCAXa4PI9FzRli1FwAAO1wfRiL9YcSbSRgBAMAG14cRKiMAANjl+jASvc+IN8P1QwEAgBWun4GjYYQsAgCAHa6fgsOGyggAADa5fgY+czt4yx0BAMClXD8Fnwkjrh8KAACscP0MzKq9AADYlVIYqampUXFxsXJychQMBtXQ0HDO7bdu3aq5c+dq3Lhxys/P1wMPPKCOjo6UOjzcWLUXAAC7HIeR2tparVq1SmvWrFFTU5MWLVqkxYsXq7m5Oen2H3zwgZYtW6bly5frb3/7m9588039+c9/1ooVKy6488OBMAIAgF2Ow8j69eu1fPlyrVixQrNnz9aGDRtUWFiojRs3Jt3+D3/4gy6//HKtXLlSxcXF+pd/+Rc9+OCD2r179wV3fjiEWbUXAACrHIWRnp4eNTY2qqysLK69rKxMO3fu
TLpPaWmpDh8+rLq6Ohlj9Pnnn+utt97SHXfcMej7dHd3KxQKxT1GSphVewEAsMpRGGlvb1c4HFZeXl5ce15enlpbW5PuU1paqq1bt6q8vFzZ2dmaNm2aJk2apOeee27Q96murlYgEIg9CgsLnXTTkTCr9gIAYFVKJ7B6zrryxBiT0Ba1d+9erVy5Uk888YQaGxv13nvv6eDBg6qoqBj09auqqtTZ2Rl7HDp0KJVuDklsbRrCCAAAVnidbDx16lRlZmYmVEHa2toSqiVR1dXVWrhwoR577DFJ0nXXXafx48dr0aJFeuqpp5Sfn5+wj8/nk8/nc9K1lMVW7SWMAABghaPKSHZ2toLBoOrr6+Pa6+vrVVpamnSfEydOKOOsG4plZmZK6quo2MaqvQAA2OX4ME1lZaVefvllbdmyRfv27dPq1avV3NwcO+xSVVWlZcuWxbZfsmSJ3nnnHW3cuFEHDhzQhx9+qJUrV2revHmaPn368H2SFESrIhKVEQAAbHF0mEaSysvL1dHRoXXr1qmlpUVz5sxRXV2dioqKJEktLS1x9xy5//771dXVpeeff17f//73NWnSJN188816+umnh+9TpKh3QBjhnBEAAOzwmIvhWMl5hEIhBQIBdXZ2yu/3D9vrnjod1tX/8z1J0r+vvU0TfI6zGQAAGMRQ529Xr03Ty2EaAACsc3UYCQ88TMMJrAAAWEEY6UdlBAAAOwgj/TiBFQAAOwgjoioCAIBN7g4jhlvBAwBgm7vDSJjKCAAAtrk7jPRXRjK5kgYAAGvcHUb6zxnJzCSMAABgC2FEVEYAALCJMCIpk3NGAACwhjAiwggAADa5O4wYwggAALa5O4xEIpIIIwAA2OTyMNL3kzACAIA9rg4jvdHKCFfTAABgjavDSITKCAAA1rk6jPRyzggAANa5OoxEDGvTAABgm6vDSG+YVXsBALDN1WGEyggAAPa5Ooz09t+BNYOraQAAsMbVYSR6O3gvq/YCAGANYURURgAAsIkwIs4ZAQDAJsKIuM8IAAA2uTuMsGovAADWuTuMUBkBAMA6woikzAxXDwMAAFa5ehaOhREKIwAAWOPqMBK9AyuX9gIAYI+rw0gMWQQAAGsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKtSCiM1NTUqLi5WTk6OgsGgGhoazrl9d3e31qxZo6KiIvl8Pl1xxRXasmVLSh0GAABji9fpDrW1tVq1apVqamq0cOFCvfjii1q8eLH27t2rmTNnJt1n6dKl+vzzz7V582Z96UtfUltbm3p7ey+48wAAIP05DiPr16/X8uXLtWLFCknShg0b9P7772vjxo2qrq5O2P69997T9u3bdeDAAU2ePFmSdPnll19YrwEAwJjh6DBNT0+PGhsbVVZWFtdeVlamnTt3Jt3n17/+tUpKSvSTn/xEM2bM0FVXXaVHH31UJ0+eHPR9uru7FQqF4h4AAGBsclQZaW9vVzgcVl5eXlx7Xl6eWltbk+5z4MABffDBB8rJydG7776r9vZ2PfTQQ/riiy8GPW+kurpaa9euddI1AACQplI6gdXj8cT9boxJaIuKRCLyeDzaunWr5s2bp9tvv13r16/Xq6++Omh1pKqqSp2dnbHHoUOHUukmAABIA44qI1OnTlVmZmZCFaStrS2hWhKVn5+vGTNmKBAIxNpmz54tY4wOHz6sK6+8MmEfn88nn8/npGsAACBNOaqMZGdnKxgMqr6+Pq69vr5epaWlSfdZuHChjhw5omPHjsXaPv74Y2VkZKigoCCFLgMAgLHE8WGayspKvfzyy9qyZYv27dun1atXq7m5WRUVFZL6DrEsW7Ystv0999yjKVOm
6IEHHtDevXu1Y8cOPfbYY/r2t7+t3Nzc4fskAAAgLTm+tLe8vFwdHR1at26dWlpaNGfOHNXV1amoqEiS1NLSoubm5tj2EyZMUH19vb73ve+ppKREU6ZM0dKlS/XUU08N36cAAABpy2OMMbY7cT6hUEiBQECdnZ3y+/3D9rov7divf637D33jhhlav/T6YXtdAAAw9PmbtWkAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVqUURmpqalRcXKycnBwFg0E1NDQMab8PP/xQXq9X119/fSpvCwAAxiDHYaS2tlarVq3SmjVr1NTUpEWLFmnx4sVqbm4+536dnZ1atmyZvv71r6fcWQAAMPY4DiPr16/X8uXLtWLFCs2ePVsbNmxQYWGhNm7ceM79HnzwQd1zzz1asGBByp0FAABjj6Mw0tPTo8bGRpWVlcW1l5WVaefOnYPu98orr2j//v168sknh/Q+3d3dCoVCcQ8AADA2OQoj7e3tCofDysvLi2vPy8tTa2tr0n0++eQTPf7449q6dau8Xu+Q3qe6ulqBQCD2KCwsdNJNAACQRlI6gdXj8cT9boxJaJOkcDise+65R2vXrtVVV1015NevqqpSZ2dn7HHo0KFUugkAANLA0EoV/aZOnarMzMyEKkhbW1tCtUSSurq6tHv3bjU1Nem73/2uJCkSicgYI6/Xq23btunmm29O2M/n88nn8znpGgAASFOOKiPZ2dkKBoOqr6+Pa6+vr1dpaWnC9n6/Xx999JH27NkTe1RUVOjLX/6y9uzZo/nz519Y7wEAQNpzVBmRpMrKSt17770qKSnRggUL9NJLL6m5uVkVFRWS+g6xfPbZZ3rttdeUkZGhOXPmxO1/2WWXKScnJ6EdAAC4k+MwUl5ero6ODq1bt04tLS2aM2eO6urqVFRUJElqaWk57z1HAAAAojzGGGO7E+cTCoUUCATU2dkpv98/bK/70o79+te6/9A3bpih9UuvH7bXBQAAQ5+/WZsGAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFjl6jBijO0eAAAAd4eR/p8eeaz2AwAAN3N1GInykEUAALDG1WGEwzQAANjn6jASRWEEAAB7XB1GjCiNAABgm7vDSH8W4ZwRAADscXUYAQAA9hFGxKW9AADYRBgRh2kAALDJ1WHEcG0vAADWuTqMRFEZAQDAHleHEQojAADY5+4wEvsXpREAAGxxdRgBAAD2pRRGampqVFxcrJycHAWDQTU0NAy67TvvvKNbb71Vl156qfx+vxYsWKD3338/5Q4PJ256BgCAfY7DSG1trVatWqU1a9aoqalJixYt0uLF
i9Xc3Jx0+x07dujWW29VXV2dGhsbddNNN2nJkiVqamq64M4PF7IIAAD2eIzD61vnz5+vG264QRs3boy1zZ49W3fddZeqq6uH9BrXXnutysvL9cQTTwxp+1AopEAgoM7OTvn9fifdPacN//axNvzbJ/rW/Jn6X//tK8P2ugAAYOjzt6PKSE9PjxobG1VWVhbXXlZWpp07dw7pNSKRiLq6ujR58uRBt+nu7lYoFIp7jAQO0wAAYJ+jMNLe3q5wOKy8vLy49ry8PLW2tg7pNX7+85/r+PHjWrp06aDbVFdXKxAIxB6FhYVOugkAANJISiewes4qJRhjEtqSeeONN/TjH/9YtbW1uuyyywbdrqqqSp2dnbHHoUOHUunmeUWPT7E2DQAA9nidbDx16lRlZmYmVEHa2toSqiVnq62t1fLly/Xmm2/qlltuOee2Pp9PPp/PSdcAAECaclQZyc7OVjAYVH19fVx7fX29SktLB93vjTfe0P3336/XX39dd9xxR2o9HQn9J41wzggAAPY4qoxIUmVlpe69916VlJRowYIFeumll9Tc3KyKigpJfYdYPvvsM7322muS+oLIsmXL9Mwzz+jGG2+MVVVyc3MVCASG8aOkjiwCAIA9jsNIeXm5Ojo6tG7dOrW0tGjOnDmqq6tTUVGRJKmlpSXuniMvvviient79fDDD+vhhx+Otd9333169dVXL/wTXACWpgEAwD7HYUSSHnroIT300ENJnzs7YPz+979P5S1GxZlLe6mNAABgC2vTAAAAq1wdRgwHagAAsM7VYQQAANjn6jDC7eABALDP3WGk/yd3YAUAwB5XhxEAAGCfq8MIh2kAALDP1WEEAADY5+owEr20l8IIAAD2uDqMAAAA+9wdRjhnBAAA61wdRmKX9pJGAACwxtVhBAAA2OfqMGIMJ7ACAGCbq8MIAACwz9VhxJy5HzwAALDE3WHEdgcAAIC7w0gUC+UBAGCPq8MIa9MAAGCfq8MIAACwz9VhhLVpAACwz9VhBAAA2OfqMMI5IwAA2OfqMAIAAOwjjIhLewEAsMnVYSS2Ng1ZBAAAa1wdRgAAgH2uDiMsTQMAgH3uDiMsTgMAgHWuDiMxnDQCAIA1rg4jhnV7AQCwztVhJIq6CAAA9rg6jHAHVgAA7HN1GAEAAPa5OoycubSX0ggAALa4O4xw/ioAANa5OoxEcc4IAAD2uDyMUBoBAMA2l4eRPhRGAACwx9VhhEt7AQCwjzACAACscnUYifJQGgEAwBpXhxHWpgEAwD5XhxEAAGCfq8MI54wAAGCfq8NIFKeMAABgj6vDCGvTAABgn7vDCIdpAACwztVhJIrDNAAA2OPqMMKlvQAA2OfqMBJFYQQAAHvcHUYojAAAYJ2rw0jsahpKIwAAWOPqMBLFpb0AANjj6jBiuLYXAADrXB1GojhMAwCAPSmFkZqaGhUXFysnJ0fBYFANDQ3n3H779u0KBoPKycnRrFmztGnTppQ6O9yoiwAAYJ/jMFJbW6tVq1ZpzZo1ampq0qJFi7R48WI1Nzcn3f7gwYO6/fbbtWjRIjU1NemHP/yhVq5cqbfffvuCOw8AANKf4zCyfv16LV++XCtWrNDs2bO1YcMGFRYWauPGjUm337Rpk2bOnKkNGzZo9uzZWrFihb797W/rZz/72QV3/kJxyggAAPY5CiM9PT1qbGxUWVlZXHtZWZl27tyZdJ9du3YlbH/bbbdp9+7dOn36dNJ9uru7FQqF4h4j4cylvZw0AgCALY7CSHt7u8LhsPLy8uLa8/Ly1NramnSf1tbWpNv39vaqvb096T7V1dUKBAKxR2FhoZNuOkYUAQDAnpROYD27kmCMOWd1Idn2ydqjqqqq1NnZGXscOnQolW6eV9k1eXr4pis0tzAwIq8PAADOz+tk46lTpyozMzOhCtLW1pZQ/YiaNm1a0u29Xq+mTJmSdB+fzyefz+ekaylZMne6lsydPuLvAwAABueoMpKdna1gMKj6+vq4
9vr6epWWlibdZ8GCBQnbb9u2TSUlJcrKynLYXQAAMNY4PkxTWVmpl19+WVu2bNG+ffu0evVqNTc3q6KiQlLfIZZly5bFtq+oqNCnn36qyspK7du3T1u2bNHmzZv16KOPDt+nAAAAacvRYRpJKi8vV0dHh9atW6eWlhbNmTNHdXV1KioqkiS1tLTE3XOkuLhYdXV1Wr16tV544QVNnz5dzz77rO6+++7h+xQAACBteUwaLNASCoUUCATU2dkpv99vuzsAAGAIhjp/szYNAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsMrx7eBtiN4kNhQKWe4JAAAYqui8fb6bvadFGOnq6pIkFRYWWu4JAABwqqurS4FAYNDn02JtmkgkoiNHjmjixInyeDzD9rqhUEiFhYU6dOgQa96MMMZ6dDDOo4NxHh2M8+gYyXE2xqirq0vTp09XRsbgZ4akRWUkIyNDBQUFI/b6fr+fP/RRwliPDsZ5dDDOo4NxHh0jNc7nqohEcQIrAACwijACAACscnUY8fl8evLJJ+Xz+Wx3ZcxjrEcH4zw6GOfRwTiPjothnNPiBFYAADB2uboyAgAA7COMAAAAqwgjAADAKsIIAACwasyHkZqaGhUXFysnJ0fBYFANDQ3n3H779u0KBoPKycnRrFmztGnTplHqaXpzMs7vvPOObr31Vl166aXy+/1asGCB3n///VHsbXpz+jcd9eGHH8rr9er6668f2Q6OEU7Hubu7W2vWrFFRUZF8Pp+uuOIKbdmyZZR6m76cjvPWrVs1d+5cjRs3Tvn5+XrggQfU0dExSr1NTzt27NCSJUs0ffp0eTwe/epXvzrvPqM+F5ox7Je//KXJysoyv/jFL8zevXvNI488YsaPH28+/fTTpNsfOHDAjBs3zjzyyCNm79695he/+IXJysoyb7311ij3PL04HedHHnnEPP300+ZPf/qT+fjjj01VVZXJysoyf/nLX0a55+nH6VhHHT161MyaNcuUlZWZuXPnjk5n01gq43znnXea+fPnm/r6enPw4EHzxz/+0Xz44Yej2Ov043ScGxoaTEZGhnnmmWfMgQMHTENDg7n22mvNXXfdNco9Ty91dXVmzZo15u233zaSzLvvvnvO7W3MhWM6jMybN89UVFTEtV199dXm8ccfT7r9D37wA3P11VfHtT344IPmxhtvHLE+jgVOxzmZa665xqxdu3a4uzbmpDrW5eXl5kc/+pF58sknCSND4HScf/vb35pAIGA6OjpGo3tjhtNx/ulPf2pmzZoV1/bss8+agoKCEevjWDOUMGJjLhyzh2l6enrU2NiosrKyuPaysjLt3Lkz6T67du1K2P62227T7t27dfr06RHrazpLZZzPFolE1NXVpcmTJ49EF8eMVMf6lVde0f79+/Xkk0+OdBfHhFTG+de//rVKSkr0k5/8RDNmzNBVV12lRx99VCdPnhyNLqelVMa5tLRUhw8fVl1dnYwx+vzzz/XWW2/pjjvuGI0uu4aNuTAtFspLRXt7u8LhsPLy8uLa8/Ly1NramnSf1tbWpNv39vaqvb1d+fn5I9bfdJXKOJ/t5z//uY4fP66lS5eORBfHjFTG+pNPPtHjjz+uhoYGeb1j9j/3YZXKOB84cEAffPCBcnJy9O6776q9vV0PPfSQvvjiC84bGUQq41xaWqqtW7eqvLxcp06dUm9vr+68804999xzo9Fl17AxF47ZykiUx+OJ+90Yk9B2vu2TtSOe03GOeuONN/TjH/9YtbW1uuyyy0aqe2PKUMc6HA7rnnvu0dq1a3XVVVeNVvfGDCd/05FIRB6PR1u3btW8efN0++23a/369Xr11VepjpyHk3Heu3evVq5cqSeeeEKNjY167733dPDgQVVUVIxGV11ltOfCMftVaerUqcrMzExI2G1tbQmJL2ratGlJt/d6vZoyZcqI9TWdpTLOUbW1tVq+fLnefPNN
3XLLLSPZzTHB6Vh3dXVp9+7dampq0ne/+11JfZOmMUZer1fbtm3TzTffPCp9Tyep/E3n5+drxowZcUulz549W8YYHT58WFdeeeWI9jkdpTLO1dXVWrhwoR577DFJ0nXXXafx48dr0aJFeuqpp6heDxMbc+GYrYxkZ2crGAyqvr4+rr2+vl6lpaVJ91mwYEHC9tu2bVNJSYmysrJGrK/pLJVxlvoqIvfff79ef/11jvcOkdOx9vv9+uijj7Rnz57Yo6KiQl/+8pe1Z88ezZ8/f7S6nlZS+ZteuHChjhw5omPHjsXaPv74Y2VkZKigoGBE+5uuUhnnEydOKCMjftrKzMyUdOabOy6clblwxE6NvQhELxvbvHmz2bt3r1m1apUZP368+cc//mGMMebxxx839957b2z76OVMq1evNnv37jWbN2/m0t4hcDrOr7/+uvF6veaFF14wLS0tscfRo0dtfYS04XSsz8bVNEPjdJy7urpMQUGB+eY3v2n+9re/me3bt5srr7zSrFixwtZHSAtOx/mVV14xXq/X1NTUmP3795sPPvjAlJSUmHnz5tn6CGmhq6vLNDU1maamJiPJrF+/3jQ1NcUuob4Y5sIxHUaMMeaFF14wRUVFJjs729xwww1m+/btsefuu+8+87WvfS1u+9///vfmq1/9qsnOzjaXX3652bhx4yj3OD05Geevfe1rRlLC47777hv9jqchp3/TAxFGhs7pOO/bt8/ccsstJjc31xQUFJjKykpz4sSJUe51+nE6zs8++6y55pprTG5ursnPzzff+ta3zOHDh0e51+nld7/73Tn/n3sxzIUeY6htAQAAe8bsOSMAACA9EEYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABY9f8B1/3PCBNkPu8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Check the new AUC\n",
    "# Predicted probabilities for class 0 (stayed) and class 1 (left)\n",
    "y_pred_proba = model.predict_proba(X_test)\n",
    "y_pred_proba[:,1]  # the second column holds the probability of class 1 (离职 = 1, i.e. the employee left)\n",
    "\n",
    "# Compute the points of the ROC curve\n",
    "from sklearn.metrics import roc_curve\n",
    "fpr, tpr, thres = roc_curve(y_test, y_pred_proba[:,1])\n",
    "\n",
    "# Plot the ROC curve\n",
    "import matplotlib.pyplot as plt\n",
    "plt.plot(fpr, tpr)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9880075960970136\n"
     ]
    }
   ],
   "source": [
    "# Compute the AUC\n",
    "from sklearn.metrics import roc_auc_score\n",
    "score = roc_auc_score(y_test, y_pred_proba[:,1])\n",
    "print(score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summary: the AUC was 0.9736 before tuning and is about 0.9880 now (output above), so tuning does improve the model's predictive performance."
   ]
  },
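  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`roc_auc_score` returns the area under the ROC curve plotted above. The sketch below uses a small hypothetical example with made-up labels and probabilities, and it assumes scikit-learn. It shows three equivalent ways to read this number: the trapezoidal area under the `roc_curve` points (via `sklearn.metrics.auc`), the value from `roc_auc_score`, and the share of (positive, negative) pairs in which the positive sample receives the higher predicted probability, with ties counted as one half."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: three equivalent views of AUC on hypothetical toy data (assumes scikit-learn)\n",
    "import numpy as np\n",
    "from sklearn.metrics import auc, roc_auc_score, roc_curve\n",
    "\n",
    "y_true_demo = np.array([0, 0, 1, 1, 0, 1, 0, 1])                   # made-up labels (1 = positive class)\n",
    "proba_demo = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.6, 0.7])  # made-up P(class 1)\n",
    "\n",
    "# 1) Trapezoidal area under the ROC curve points\n",
    "fpr_demo, tpr_demo, _ = roc_curve(y_true_demo, proba_demo)\n",
    "area_curve = auc(fpr_demo, tpr_demo)\n",
    "\n",
    "# 2) Direct computation\n",
    "area_direct = roc_auc_score(y_true_demo, proba_demo)\n",
    "\n",
    "# 3) Share of positive/negative pairs ranked correctly (ties count 0.5)\n",
    "pos = proba_demo[y_true_demo == 1]\n",
    "neg = proba_demo[y_true_demo == 0]\n",
    "area_pairs = np.mean([(p > n) + 0.5 * (p == n) for p in pos for n in neg])\n",
    "\n",
    "print(area_curve, area_direct, area_pairs)  # all three equal 14/16 = 0.875"
   ]
  },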
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.00292445, 0.37460102, 0.11992881, 0.19962956, 0.07719516,\n",
       "       0.22572099])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feature importances of the current (tuned) model\n",
    "model.feature_importances_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>特征名称</th>\n",
       "      <th>特征重要性</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>满意度</td>\n",
       "      <td>0.374601</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>工龄</td>\n",
       "      <td>0.225721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>工程数量</td>\n",
       "      <td>0.199630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>考核得分</td>\n",
       "      <td>0.119929</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>月工时</td>\n",
       "      <td>0.077195</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>工资</td>\n",
       "      <td>0.002924</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   特征名称     特征重要性\n",
       "1   满意度  0.374601\n",
       "5    工龄  0.225721\n",
       "3  工程数量  0.199630\n",
       "2  考核得分  0.119929\n",
       "4   月工时  0.077195\n",
       "0    工资  0.002924"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
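    "# Pair each feature name with its importance\n",
    "features = X.columns\n",
    "importances = model.feature_importances_\n",
    "\n",
    "# Show them as a table: 特征名称 = feature name, 特征重要性 = feature importance\n",
    "importances_df = pd.DataFrame()  # start from an empty DataFrame\n",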
    "importances_df['特征名称'] = features\n",
    "importances_df['特征重要性'] = importances\n",
    "\n",
    "importances_df.sort_values('特征重要性', ascending=False)"
   ]
  },
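  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For `DecisionTreeClassifier`, `feature_importances_` is the normalized total reduction in the split criterion (Gini or entropy) contributed by each feature. As a result, the values sum to 1 for any tree that makes at least one split. A more compact way to pair names with importances is a `pd.Series` indexed by column name. The sketch assumes scikit-learn and pandas and uses synthetic data with hypothetical column names `f0` to `f5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: feature importances as a sorted pandas Series (assumes scikit-learn + pandas; synthetic data)\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "X_arr, y_arr = make_classification(n_samples=500, n_features=6, random_state=0)\n",
    "X_named = pd.DataFrame(X_arr, columns=[f'f{i}' for i in range(6)])  # hypothetical column names\n",
    "tree_demo = DecisionTreeClassifier(max_depth=7, random_state=123).fit(X_named, y_arr)\n",
    "\n",
    "imp = pd.Series(tree_demo.feature_importances_, index=X_named.columns).sort_values(ascending=False)\n",
    "print(imp)\n",
    "print('sum of importances:', imp.sum())  # normalized, so 1 up to floating-point rounding"
   ]
  },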
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
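    "**The decision tree model has other hyperparameters as well:**\n",
    "\n",
    "Commonly used hyperparameters of the classification tree `DecisionTreeClassifier()` and what they do:\n",
    "1. criterion: the measure of split quality, either \"gini\" (Gini impurity) or \"entropy\" (information entropy); default \"gini\". Recent scikit-learn versions also accept \"log_loss\", which is equivalent to \"entropy\".\n",
    "2. splitter: \"best\" or \"random\"; default \"best\". \"best\" searches all candidate split points for the optimal one and suits moderately sized data; \"random\" picks the best among randomly drawn split points and suits very large data.\n",
    "3. max_depth: the maximum depth of the tree, an int or None; default None (no limit). With little data or few features it can usually be left unset; with many samples or features, limiting the depth keeps the tree in check.\n",
    "4. min_samples_split: the minimum number of samples a node must contain to be split further; default 2. A node with fewer samples is not split and becomes a leaf.\n",
    "5. min_samples_leaf: the minimum number of samples required in a leaf; default 1. A split is only considered if it leaves at least this many training samples in each of the two child nodes.\n",
    "6. min_weight_fraction_leaf: the minimum fraction of the total sample weight that a leaf must hold; default 0, i.e. weights are ignored. Splits that would produce a leaf below this fraction are not made. Consider it when many samples have missing values or the class distribution is heavily skewed, so that sample weights matter.\n",
    "7. max_features: the maximum number of features considered when searching for the best split; default None (all features). Accepts an int, or a float interpreted as a fraction of the features.\n",
    "8. max_leaf_nodes: the maximum number of leaf nodes, an int; default None (unlimited).\n",
    "9. class_weight: class weights; default None. \"balanced\" gives classes with fewer samples a higher weight, and a dict sets the weights explicitly. It mainly keeps the tree from being biased toward classes that dominate the training set. Besides class_weight, class imbalance can also be handled with oversampling and undersampling, which are covered in Chapter 11: Data Preprocessing.\n",
    "10. random_state: with large data or many features, two features can give exactly the same information gain (or Gini decrease) at some node; the tree then picks one of them at random, so repeated runs can produce different trees. Setting random_state (e.g. to 123) makes every node split identically on each run, which matters most when there are many features and the tree is deep."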
   ]
  },
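  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For `class_weight='balanced'`, scikit-learn documents each class weight as `n_samples / (n_classes * np.bincount(y))`, so rarer classes get larger weights. The sketch below checks this formula with `sklearn.utils.class_weight.compute_class_weight` on hypothetical labels: 8 samples of class 0 and 2 of class 1. It assumes scikit-learn is installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: what class_weight='balanced' computes (assumes scikit-learn; hypothetical labels)\n",
    "import numpy as np\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "y_imbalanced = np.array([0] * 8 + [1] * 2)  # 8 samples of class 0, 2 of class 1\n",
    "classes = np.unique(y_imbalanced)\n",
    "\n",
    "weights = compute_class_weight('balanced', classes=classes, y=y_imbalanced)\n",
    "manual = len(y_imbalanced) / (len(classes) * np.bincount(y_imbalanced))\n",
    "\n",
    "print(weights)  # class 0: 10 / (2 * 8) = 0.625, class 1: 10 / (2 * 2) = 2.5\n",
    "print(manual)"
   ]
  },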
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**3. Tuning several parameters at once**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'criterion': 'entropy', 'max_depth': 11, 'min_samples_split': 13}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# Candidate values for each decision tree hyperparameter\n",
    "parameters = {'max_depth': [5, 7, 9, 11, 13], 'criterion':['gini', 'entropy'], 'min_samples_split':[5, 7, 9, 11, 13, 15]}\n",
    "# Build the decision tree classifier\n",
    "model = DecisionTreeClassifier()  # no fixed hyperparameters here, because they are what we are tuning\n",
    "\n",
    "# Grid search over all 5 * 2 * 6 = 60 combinations\n",
    "grid_search = GridSearchCV(model, parameters, scoring='roc_auc', cv=5)\n",
    "grid_search.fit(X_train, y_train)\n",
    "\n",
    "# Best combination found\n",
    "grid_search.best_params_"
   ]
  },
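  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides `best_params_`, the fitted search object stores every candidate's scores in `cv_results_`, which is easy to inspect as a DataFrame. With the default `refit=True` it also stores `best_estimator_`, a copy of the model already refitted on the whole training set with the best parameters. The sketch repeats a 60-candidate grid like the one above on synthetic data. It assumes scikit-learn and pandas, and `X_demo`, `y_demo` and `gs_multi` are hypothetical names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: inspecting a multi-parameter grid search (assumes scikit-learn + pandas; synthetic data)\n",
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "X_demo, y_demo = make_classification(n_samples=600, n_features=6, random_state=2)\n",
    "param_grid_demo = {'max_depth': [5, 7, 9, 11, 13], 'criterion': ['gini', 'entropy'],\n",
    "                   'min_samples_split': [5, 7, 9, 11, 13, 15]}\n",
    "gs_multi = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid_demo, scoring='roc_auc', cv=5)\n",
    "gs_multi.fit(X_demo, y_demo)\n",
    "\n",
    "results = pd.DataFrame(gs_multi.cv_results_)\n",
    "print(len(results))  # 5 * 2 * 6 = 60 candidate combinations\n",
    "print(results.sort_values('rank_test_score')[['params', 'mean_test_score', 'std_test_score']].head())\n",
    "print(gs_multi.best_params_, gs_multi.best_score_)\n",
    "\n",
    "best_tree = gs_multi.best_estimator_  # already refitted on all of X_demo, y_demo"
   ]
  },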
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9823333333333333\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the model with the result of the multi-parameter search\n",
    "model = DecisionTreeClassifier(criterion='entropy', max_depth=11, min_samples_split=13)\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "# Overall accuracy on the test set\n",
    "y_pred = model.predict(X_test)\n",
    "from sklearn.metrics import accuracy_score\n",
    "score = accuracy_score(y_test, y_pred)  # argument order is (y_true, y_pred)\n",
    "print(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9890716883220967\n"
     ]
    }
   ],
   "source": [
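    "# Check the new AUC\n",
    "# Predicted probabilities for class 0 (stayed) and class 1 (left)\n",
    "y_pred_proba = model.predict_proba(X_test)\n",
    "y_pred_proba[:,1]  # the second column holds the probability of class 1 (离职 = 1, i.e. the employee left)\n",
    "\n",
    "from sklearn.metrics import roc_auc_score\n",
    "score = roc_auc_score(y_test, y_pred_proba[:,1])\n",
    "print(score)"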
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
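    "Summary: compared with the single-parameter result, multi-parameter tuning improves both test accuracy (0.9823 vs. 0.982) and AUC (0.9891 vs. 0.9880) slightly."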
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
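    "**Note 1: joint multi-parameter tuning vs. tuning each parameter separately**\n",
    "\n",
    "Tuning several parameters jointly is not the same as tuning them one at a time. To save effort, some readers run three separate single-parameter searches for the three parameters above and then combine the results, but this is not rigorous. During a single-parameter search, every other parameter stays at its default value. Combinations in which several parameters depart from their defaults are therefore never evaluated, so the joint effect of the parameters on the model is ignored. In the example above, the joint search evaluates 5 × 2 × 6 = 60 combinations, whereas three single-parameter searches evaluate only 5 + 2 + 6 = 13.\n",
    "So if only one parameter needs tuning, a single-parameter search is enough; if several do, a joint multi-parameter search is recommended.\n",
    "\n",
    "**Note 2: the best value lies on the boundary of the search range**\n",
    "\n",
    "Another point to watch: if a parameter value returned by GridSearchCV() lies on the boundary of the range you supplied, a value outside that range may give an even better model, so extend the range and tune again. For example, had the code above returned a best max_depth of 13 (the largest value in the grid), the truly best max_depth might be larger still. In that case, shift the search grid, e.g. to [9, 11, 13, 15, 17], and rerun the search."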
   ]
  }
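  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both notes can be checked mechanically. The sketch below uses only the Python standard library. `itertools.product` enumerates the joint grid: 60 combinations, compared with 5 + 2 + 6 = 13 for three separate searches. `on_boundary` is a hypothetical helper that flags numeric parameters whose best value is the smallest or largest candidate, which suggests the range should be extended. The first call uses the best parameters found above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (standard library only): joint grid size and a boundary check for best_params_\n",
    "from itertools import product\n",
    "\n",
    "param_grid_demo = {'max_depth': [5, 7, 9, 11, 13], 'criterion': ['gini', 'entropy'],\n",
    "                   'min_samples_split': [5, 7, 9, 11, 13, 15]}\n",
    "\n",
    "joint = list(product(*param_grid_demo.values()))               # every combination of values\n",
    "one_at_a_time = sum(len(v) for v in param_grid_demo.values())  # three separate single-parameter searches\n",
    "print(len(joint), one_at_a_time)                               # 60 13\n",
    "\n",
    "\n",
    "def on_boundary(best_params, grid):\n",
    "    # Hypothetical helper: names of numeric parameters whose best value is the min or max candidate\n",
    "    flagged = []\n",
    "    for name, value in best_params.items():\n",
    "        candidates = grid[name]\n",
    "        numeric = all(isinstance(c, (int, float)) for c in candidates)\n",
    "        if numeric and value in (min(candidates), max(candidates)):\n",
    "            flagged.append(name)\n",
    "    return flagged\n",
    "\n",
    "\n",
    "print(on_boundary({'criterion': 'entropy', 'max_depth': 11, 'min_samples_split': 13}, param_grid_demo))  # []\n",
    "print(on_boundary({'criterion': 'gini', 'max_depth': 13, 'min_samples_split': 9}, param_grid_demo))     # ['max_depth']"
   ]
  }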
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
